library(tidyverse)
library(brms)
library(tidybayes)
library(bayesplot)
library(cowplot)
library(scales)
library(hexbin)
library(glue)

# Shared ggplot theme for every figure in this report.
theme_set(theme_light())

# Project-wide helpers (model family pieces and the data simulator).
glue("{params$common_dir_str}/brms_model.R") %>% source()
glue("{params$common_dir_str}/simulation.R") %>% source()

# Load observations ----

# Observed errors: the CSV column `error` is kept in degrees as `obs_degree`,
# and `error` itself is converted to radians for the model.
obs_only <- glue("{params$model_dir_str}/data/stimulation_obvs.csv") %>%
  read_csv() %>%
  mutate(
    subj       = as_factor(subj),
    obs_degree = error,
    error      = obs_degree * (pi/180)
  )
## Parsed with column specification:
## cols(
##   subj = col_double(),
##   subj_index = col_double(),
##   stimulation = col_double(),
##   error = col_double()
## )
# obs_only <- sim_data %>%
#   unnest(subj_obs) %>%
#   select(subj_index = subj, 
#          error = obs_radian,
#          obs_degree, 
#          stimulation)

# Peek at data ----

# Pre-stimulation errors (stimulation == 0)
# Per-subject error distribution for stimulation == 0: density-scaled
# histogram, rug of individual errors, and kernel density overlay.
obs_only %>%
  filter(stimulation == 0) %>%
  ggplot(aes(x = obs_degree)) +
  geom_histogram(aes(y = ..density..), binwidth = 10) +
  geom_rug() +
  geom_density(aes(y = ..density..)) +
  facet_wrap(vars(subj), ncol = 1)

# Post-stimulation errors (stimulation == 1)
# Per-subject error distribution for stimulation == 1: density-scaled
# histogram, rug of individual errors, and kernel density overlay.
obs_only %>%
  filter(stimulation == 1) %>%
  ggplot(aes(x = obs_degree)) +
  geom_histogram(aes(y = ..density..), binwidth = 10) +
  geom_rug() +
  geom_density(aes(y = ..density..)) +
  facet_wrap(vars(subj), ncol = 1)

# Fit brms model ----

# Model-specific prior script; presumably defines `bprior_full`, which is
# printed right after sourcing -- confirm against model_prior.R.
glue("{params$model_dir_str}/model_prior.R") %>% source()

print(bprior_full)
##              prior class        coef group resp   dpar nlpar bound
## 1 normal(3.8, 0.4)     b   intercept            circSD            
## 2   normal(0, 0.4)     b stimulation            circSD            
## 3  normal(0, 0.25)    sd   Intercept  subj      circSD            
## 4  normal(0, 0.25)    sd stimulation  subj      circSD            
## 5   normal(0, 1.5)     b   intercept             theta            
## 6     normal(0, 1)     b stimulation             theta            
## 7   normal(0, 0.5)    sd   Intercept  subj       theta            
## 8   normal(0, 0.5)    sd stimulation  subj       theta
# Sampler settings shared by the fit below.
iter   <- 4000
warmup <- 2000
cores  <- 4
chains <- 4
# Total post-warmup draws kept across all chains.
n_post_samples <- chains * (iter - warmup)

# Fit the hierarchical model with the custom `vm_uniform_mix` family.
# sample_prior = "yes" also stores prior draws (the prior_* variables used
# later); `file` lets brms reload a previously saved fit instead of refitting.
model_fit <- brm(
  formula      = bform_full,
  data         = obs_only,
  family       = vm_uniform_mix,
  prior        = bprior_full,
  stanvars     = stanvars,
  sample_prior = "yes",
  warmup       = warmup,
  iter         = iter,
  cores        = cores,
  chains       = chains,
  control      = list(adapt_delta = 0.99),
  inits        = 0,
  file         = glue("{params$save_dir_str}/obs_model_fit")
)
## Compiling the C++ model
## Start sampling
print(model_fit)
##  Family: vm_uniform_mix 
##   Links: mu = identity; circSD = log; theta = logit; a = identity; b = identity 
## Formula: error ~ 0 
##          circSD ~ 0 + intercept + stimulation + (1 + stimulation || subj)
##          theta ~ 0 + intercept + stimulation + (1 + stimulation || subj)
##          a = -3.14
##          b = 3.14
##    Data: obs_only (Number of observations: 504) 
## Samples: 4 chains, each with iter = 4000; warmup = 2000; thin = 1;
##          total post-warmup samples = 8000
## 
## Group-Level Effects: 
## ~subj (Number of levels: 2) 
##                        Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
## sd(circSD_Intercept)       0.21      0.14     0.01     0.54 1.00     3479
## sd(circSD_stimulation)     0.22      0.15     0.01     0.55 1.00     4366
## sd(theta_Intercept)        0.54      0.27     0.10     1.18 1.00     4560
## sd(theta_stimulation)      0.33      0.26     0.01     0.97 1.00     4779
##                        Tail_ESS
## sd(circSD_Intercept)       3271
## sd(circSD_stimulation)     2945
## sd(theta_Intercept)        2853
## sd(theta_stimulation)      3155
## 
## Population-Level Effects: 
##                    Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
## circSD_intercept       3.36      0.20     3.01     3.81 1.00     4739
## circSD_stimulation     0.24      0.23    -0.24     0.69 1.00     5548
## theta_intercept        0.13      0.45    -0.76     1.03 1.00     3444
## theta_stimulation     -0.18      0.42    -1.00     0.67 1.00     5094
##                    Tail_ESS
## circSD_intercept       3873
## circSD_stimulation     4968
## theta_intercept        3829
## theta_stimulation      4900
## 
## Family Specific Parameters: 
##                      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
## b_circSD_intercept       3.36      0.20     3.01     3.81 1.00     4739
## b_circSD_stimulation     0.24      0.23    -0.24     0.69 1.00     5548
## b_theta_intercept        0.13      0.45    -0.76     1.03 1.00     3444
## b_theta_stimulation     -0.18      0.42    -1.00     0.67 1.00     5094
##                      Tail_ESS
## b_circSD_intercept       3873
## b_circSD_stimulation     4968
## b_theta_intercept        3829
## b_theta_stimulation      4900
## 
## Samples were drawn using sampling(NUTS). For each parameter, Eff.Sample 
## is a crude measure of effective sample size, and Rhat is the potential 
## scale reduction factor on split chains (at convergence, Rhat = 1).

# Fit checks ----

# Divergences

# Sampler diagnostics: NUTS parameters, R-hat and effective-sample-size ratio.
np       <- nuts_params(model_fit)
rhat     <- brms::rhat(model_fit)
neff_rat <- neff_ratio(model_fit)

# Total number of divergent transitions over all chains.
np %>%
  filter(Parameter == "divergent__") %>%
  summarise(n_div = sum(Value))
##   n_div
## 1     0

# R-hat

# R-hat per parameter. The added x scale replaces bayesplot's own x scale
# (hence the "Scale for 'x' is already present" message).
mcmc_rhat(rhat) +
  yaxis_text(hjust = 1) +
  scale_x_continuous(breaks = pretty_breaks(6))
## Scale for 'x' is already present. Adding another scale for 'x', which
## will replace the existing scale.

# Effective sample size ratio

# Effective-sample-size ratio per parameter, with parameter labels shown.
mcmc_neff(neff_rat) +
  yaxis_text(hjust = 1)

# Trace plots

# Trace plots of every sampled parameter, from the underlying Stan fit.
model_fit$fit %>%
  as.array() %>%
  mcmc_trace()

# Other diagnostics

# brms' default density + trace panels; ask = FALSE draws every page without prompting.
plot(x = model_fit, ask = FALSE)

# Plot posteriors ----

# Arrange samples

# Compute group-level summaries for the plots below.
#
# For every posterior draw:
#   <param>_mean : population-level value; circSD = exp(circSD linear
#                  predictor), pMem = inv_logit(theta linear predictor).
#   <param>_pred : value for one new, simulated subject. The subject's
#                  intercept is drawn ONCE per draw and shared by its pre and
#                  post values, so <param>_ES_pred is a within-subject change
#                  (the same construction as the per-subject ES samples, which
#                  reuse one subject offset for pre and post). Drawing the
#                  intercept separately for pre and post would make ES the
#                  difference between two unrelated subjects and inflate it.
# The result has one row per draw x param (circSD_pre, circSD_post, circSD_ES,
# pMem_pre, pMem_post, pMem_ES) with the columns `mean` and `pred`.

group_level_samples <- 
  spread_draws(model_fit, `(b|sd)_.*`, regex = TRUE) %>%
  mutate(
         # group level parameters
         circSD_pre_mean  = exp(b_circSD_intercept),
         circSD_post_mean = exp(b_circSD_intercept + b_circSD_stimulation),
         circSD_ES_mean   = circSD_post_mean - circSD_pre_mean,
         pMem_pre_mean    = inv_logit(b_theta_intercept),
         pMem_post_mean   = inv_logit(b_theta_intercept + b_theta_stimulation),
         pMem_ES_mean     = pMem_post_mean - pMem_pre_mean,
         # one new subject per draw: intercept on the linear-predictor scale,
         # shared by that subject's pre and post predictions
         circSD_int_draw  = rnorm(n(), b_circSD_intercept, sd_subj__circSD_Intercept),
         pMem_int_draw    = rnorm(n(), b_theta_intercept, sd_subj__theta_Intercept),
         # predictive dist for a new subject
         circSD_pre_pred  = exp(circSD_int_draw),
         circSD_post_pred = exp(circSD_int_draw +
                                rnorm(n(), b_circSD_stimulation, sd_subj__circSD_stimulation)),
         circSD_ES_pred   = circSD_post_pred - circSD_pre_pred,
         pMem_pre_pred    = inv_logit(pMem_int_draw),
         pMem_post_pred   = inv_logit(pMem_int_draw +
                                      rnorm(n(), b_theta_stimulation, sd_subj__theta_stimulation)),
         pMem_ES_pred     = pMem_post_pred - pMem_pre_pred
         ) %>% 
  # drop raw coefficients / SDs and the helper intercept draws before reshaping
  select(-contains("b_"), -contains("sd_subj"), -ends_with("_int_draw")) %>%
  pivot_longer(-contains("."), names_to = c("param", "stat"), names_pattern = "(.*)_(.*)", values_to = "value") %>%
  pivot_wider(names_from = stat, values_from = value)
  
group_level_summary <- 
  group_level_samples %>%
  group_by(param) %>%
  median_qi(.width = c(.5, .8, .95))
  #median_qi(.width = c(.90, .95))



# Per-subject circSD posteriors. Each subject's random offsets
# (Intercept, stimulation) are added to the population coefficients, and the
# result is mapped to the circSD scale with exp(). Output is long in `param`
# (circSD_pre / circSD_post / circSD_ES) with one `value` per draw and subject.
circSD_subj_samples <- model_fit %>%
  spread_draws(b_circSD_intercept, b_circSD_stimulation, r_subj__circSD[subj, term]) %>%
  ungroup() %>%
  pivot_wider(names_from = term, values_from = r_subj__circSD, names_prefix = "offset_") %>%
  mutate(
    circSD_pre  = exp(b_circSD_intercept + offset_Intercept),
    circSD_post = exp(b_circSD_intercept + offset_Intercept +
                      b_circSD_stimulation + offset_stimulation),
    circSD_ES   = circSD_post - circSD_pre
  ) %>%
  select(-b_circSD_intercept, -offset_Intercept, -b_circSD_stimulation, -offset_stimulation) %>%
  pivot_longer(contains("circSD"), names_to = "param", values_to = "value")

# Median with 90% and 95% quantile intervals per subject and parameter.
circSD_subj_summary <- circSD_subj_samples %>%
  group_by(subj, param) %>%
  median_qi(.width = c(.90, .95))
## Warning: unnest() has a new interface. See ?unnest for details.
## Try `df %>% unnest(c(.lower, .upper))`, with `mutate()` if needed

## Warning: unnest() has a new interface. See ?unnest for details.
## Try `df %>% unnest(c(.lower, .upper))`, with `mutate()` if needed
# Per-subject pMem posteriors. Each subject's random theta offsets
# (Intercept, stimulation) are added to the population coefficients, and the
# result is mapped to a probability with inv_logit(). Output is long in `param`
# (pMem_pre / pMem_post / pMem_ES) with one `value` per draw and subject.
pMem_subj_samples <- model_fit %>%
  spread_draws(b_theta_intercept, b_theta_stimulation, r_subj__theta[subj, term]) %>%
  ungroup() %>%
  pivot_wider(names_from = term, values_from = r_subj__theta, names_prefix = "offset_") %>%
  mutate(
    pMem_pre  = inv_logit(b_theta_intercept + offset_Intercept),
    pMem_post = inv_logit(b_theta_intercept + offset_Intercept +
                          b_theta_stimulation + offset_stimulation),
    pMem_ES   = pMem_post - pMem_pre
  ) %>%
  select(-b_theta_intercept, -offset_Intercept, -b_theta_stimulation, -offset_stimulation) %>%
  pivot_longer(contains("pMem"), names_to = "param", values_to = "value")

# Median with 90% and 95% quantile intervals per subject and parameter.
pMem_subj_summary <- pMem_subj_samples %>%
  group_by(subj, param) %>%
  median_qi(.width = c(.90, .95))
## Warning: unnest() has a new interface. See ?unnest for details.
## Try `df %>% unnest(c(.lower, .upper))`, with `mutate()` if needed

## Warning: unnest() has a new interface. See ?unnest for details.
## Try `df %>% unnest(c(.lower, .upper))`, with `mutate()` if needed
# Median and 95% interval of each group-level mean (predictive draws excluded).
group_level_samples %>%
  select(-pred) %>%
  group_by(param) %>%
  median_qi(.width = .95)
## Warning: unnest() has a new interface. See ?unnest for details.
## Try `df %>% unnest(c(.lower, .upper))`, with `mutate()` if needed
## # A tibble: 6 x 7
##   param           mean  .lower .upper .width .point .interval
##   <chr>          <dbl>   <dbl>  <dbl>  <dbl> <chr>  <chr>    
## 1 circSD_ES     8.10    -7.07  28.6     0.95 median qi       
## 2 circSD_post  36.5     21.9   65.9     0.95 median qi       
## 3 circSD_pre   28.5     20.3   45.0     0.95 median qi       
## 4 pMem_ES      -0.0436  -0.233  0.155   0.95 median qi       
## 5 pMem_post     0.483    0.240  0.753   0.95 median qi       
## 6 pMem_pre      0.530    0.319  0.738   0.95 median qi

# Group-level posteriors

# Overview of circSD (pre, post, ES). Each row shows:
#   - the group-mean posterior as a halfeye, nudged up;
#   - predictive intervals for a new subject;
#   - each subject's posterior medians as points.
circSD_p1 <- group_level_samples %>%
  filter(str_detect(param, "circSD")) %>%
  ggplot() +
  geom_halfeyeh(aes(y = param, x = mean), .width = c(.90, .95),
                position = position_nudge(y = 0.15)) +
  stat_intervalh(aes(y = param, x = pred), .width = c(.5, .8, .95)) +
  geom_point(data = circSD_subj_summary, aes(y = param, x = value), size = 2) +
  scale_color_brewer() +
  scale_x_continuous(breaks = pretty_breaks(10)) +
  coord_cartesian(xlim = c(-50, 150)) +
  labs(
    subtitle = "circSD: group level mean posterior (median, 90%, 95% interval), \nsubject posterior medians, \ncondition predictive dist of subjects",
    x        = "circSD",
    color    = "interval"
  )

# Same overview for pMem (pre, post, ES), without an x-axis zoom.
pMem_p1 <- group_level_samples %>%
  filter(str_detect(param, "pMem")) %>%
  ggplot() +
  geom_halfeyeh(aes(y = param, x = mean), .width = c(.90, .95),
                position = position_nudge(y = 0.15)) +
  stat_intervalh(aes(y = param, x = pred), .width = c(.5, .8, .95)) +
  geom_point(data = pMem_subj_summary, aes(y = param, x = value), size = 2) +
  scale_color_brewer() +
  scale_x_continuous(breaks = pretty_breaks(10)) +
  labs(
    subtitle = "pMem: group level mean posterior (median, 90%, 95% interval), \nsubject posterior medians, \ncondition predictive dist of subjects",
    x        = "pMem",
    color    = "interval"
  )

plot_grid(circSD_p1, pMem_p1, align = "hv", ncol = 1)

# circSD: pre, post, and effect size (ES)

# Detail plot for one parameter:
#   - halfeye rows for the group-mean posterior and for each subject's
#     posterior (rows labelled "<param>_<subj>");
#   - the new-subject predictive intervals, nudged just below the group row;
#   - each subject's posterior medians, at the same nudged position.
#
# param_name    : parameter label, e.g. "circSD_pre". It is matched with
#                 str_detect() for the sample tables (so it acts as a regex)
#                 and exactly for the summary table.
# subj_samples  : per-draw subject samples with columns param, subj, value.
# subj_summary  : per-subject median_qi() summary with columns param, value.
# subtitle      : plot subtitle.
# x_label       : x-axis label.
# group_samples : group-level draws with columns param, mean, pred.
plot_param_posterior_detail <- function(param_name, subj_samples, subj_summary,
                                        subtitle, x_label,
                                        group_samples = group_level_samples) {
  group_rows <- group_samples %>%
    filter(str_detect(param, .env$param_name))

  halfeye_rows <- rbind(
    group_rows %>% select(-pred, value = mean),
    subj_samples %>%
      filter(str_detect(param, .env$param_name)) %>%
      unite(param, param, subj)
  )

  ggplot() +
    geom_halfeyeh(data = halfeye_rows, aes(y = param, x = value), .width = c(.90, .95)) +
    stat_intervalh(data = group_rows,
                   aes(y = param, x = pred),
                   .width = c(.5, .8, .95),
                   position = position_nudge(y = -0.15)) +
    geom_point(data = subj_summary %>% filter(param == .env$param_name),
               aes(y = param, x = value),
               size = 2,
               position = position_nudge(y = -0.15)) +
    scale_color_brewer() +
    scale_x_continuous(breaks = pretty_breaks(10)) +
    labs(subtitle = subtitle, x = x_label, color = "interval")
}

circSD_p2 <- plot_param_posterior_detail(
  "circSD_pre", circSD_subj_samples, circSD_subj_summary,
  subtitle = "circSD_pre: group level mean posterior (median, 90%, 95% interval), \nsubject posterior, \ncondition predictive dist of subjects",
  x_label  = "circSD"
)

circSD_p3 <- plot_param_posterior_detail(
  "circSD_post", circSD_subj_samples, circSD_subj_summary,
  subtitle = "circSD_post: group level mean posterior (median, 90%, 95% interval), \nsubject posterior, \ncondition predictive dist of subjects",
  x_label  = "circSD"
)

circSD_p4 <- plot_param_posterior_detail(
  "circSD_ES", circSD_subj_samples, circSD_subj_summary,
  subtitle = "circSD_ES: group level mean posterior (median, 90%, 95% interval), \nsubject posterior, \nES predictive dist of subjects",
  x_label  = "delta circSD"
)

plot_grid(circSD_p2, circSD_p3, circSD_p4, ncol = 1, align = "hv")

# pMem: pre, post, and effect size (ES)

# Detail plot for one parameter:
#   - halfeye rows for the group-mean posterior and for each subject's
#     posterior (rows labelled "<param>_<subj>");
#   - the new-subject predictive intervals, nudged just below the group row;
#   - each subject's posterior medians, at the same nudged position.
#
# param_name    : parameter label, e.g. "pMem_pre". It is matched with
#                 str_detect() for the sample tables (so it acts as a regex)
#                 and exactly for the summary table.
# subj_samples  : per-draw subject samples with columns param, subj, value.
# subj_summary  : per-subject median_qi() summary with columns param, value.
# subtitle      : plot subtitle.
# x_label       : x-axis label.
# group_samples : group-level draws with columns param, mean, pred.
plot_param_posterior_detail <- function(param_name, subj_samples, subj_summary,
                                        subtitle, x_label,
                                        group_samples = group_level_samples) {
  group_rows <- group_samples %>%
    filter(str_detect(param, .env$param_name))

  halfeye_rows <- rbind(
    group_rows %>% select(-pred, value = mean),
    subj_samples %>%
      filter(str_detect(param, .env$param_name)) %>%
      unite(param, param, subj)
  )

  ggplot() +
    geom_halfeyeh(data = halfeye_rows, aes(y = param, x = value), .width = c(.90, .95)) +
    stat_intervalh(data = group_rows,
                   aes(y = param, x = pred),
                   .width = c(.5, .8, .95),
                   position = position_nudge(y = -0.15)) +
    geom_point(data = subj_summary %>% filter(param == .env$param_name),
               aes(y = param, x = value),
               size = 2,
               position = position_nudge(y = -0.15)) +
    scale_color_brewer() +
    scale_x_continuous(breaks = pretty_breaks(10)) +
    labs(subtitle = subtitle, x = x_label, color = "interval")
}

pMem_p2 <- plot_param_posterior_detail(
  "pMem_pre", pMem_subj_samples, pMem_subj_summary,
  subtitle = "pMem_pre: group level mean posterior (median, 90%, 95% interval), \nsubject posterior, \ncondition predictive dist of subjects",
  x_label  = "pMem"
)

pMem_p3 <- plot_param_posterior_detail(
  "pMem_post", pMem_subj_samples, pMem_subj_summary,
  subtitle = "pMem_post: group level mean posterior (median, 90%, 95% interval), \nsubject posterior, \ncondition predictive dist of subjects",
  x_label  = "pMem"
)

pMem_p4 <- plot_param_posterior_detail(
  "pMem_ES", pMem_subj_samples, pMem_subj_summary,
  subtitle = "pMem_ES: group level mean posterior (median, 90%, 95% interval), \nsubject posterior, \nES predictive dist of subjects",
  x_label  = "delta pMem"
)

plot_grid(pMem_p2, pMem_p3, pMem_p4, ncol = 1, align = "hv")

# Group-level joint posteriors

# Pairs plot of the group-level means: one column per param, one row per draw.
# The draw-id columns (.chain etc.) are dropped before plotting, which is
# presumably why bayesplot warns about a single chain.
group_level_samples %>%
  pivot_wider(
    id_cols     = contains("."),
    names_from  = param,
    values_from = mean
  ) %>%
  select(-contains(".")) %>%
  mcmc_pairs(off_diag_fun = "hex")
## Warning: Only one chain in 'x'. This plot is more useful with multiple
## chains.

# Plot posteriors with priors ----

# Compute group-level summaries for the prior + posterior plots.
#
# The same quantities are built twice: from the prior draws (prior_* names)
# and from the posterior draws.
#   <param>_mean : population-level value (exp for circSD, inv_logit for pMem).
#   <param>_pred : value for one new, simulated subject. The subject intercept
#                  is drawn once per draw and shared by its pre and post
#                  values, so <param>_ES_pred is a within-subject change
#                  rather than the difference of two unrelated subjects.
# The result has one row per draw x param (with and without the "prior_"
# prefix) and the columns `mean` and `pred`. It replaces the posterior-only
# group_level_samples.

group_level_samples <- 
  spread_draws(model_fit, `(prior_)?(b|sd)_.*`, regex = TRUE) %>%
  mutate(
         # prior
         # group level parameters
         prior_circSD_pre_mean  = exp(prior_b_circSD_intercept),
         prior_circSD_post_mean = exp(prior_b_circSD_intercept + prior_b_circSD_stimulation),
         prior_circSD_ES_mean   = prior_circSD_post_mean - prior_circSD_pre_mean,
         prior_pMem_pre_mean    = inv_logit(prior_b_theta_intercept),
         prior_pMem_post_mean   = inv_logit(prior_b_theta_intercept + prior_b_theta_stimulation),
         prior_pMem_ES_mean     = prior_pMem_post_mean - prior_pMem_pre_mean,   
         # one new subject per draw, shared by its pre and post predictions
         prior_circSD_int_draw  = rnorm(n(), prior_b_circSD_intercept, prior_sd_subj__circSD_Intercept),
         prior_pMem_int_draw    = rnorm(n(), prior_b_theta_intercept, prior_sd_subj__theta_Intercept),
         # predictive dist for a new subject
         prior_circSD_pre_pred  = exp(prior_circSD_int_draw),
         prior_circSD_post_pred = exp(prior_circSD_int_draw +
                                      rnorm(n(), prior_b_circSD_stimulation, prior_sd_subj__circSD_stimulation)),
         prior_circSD_ES_pred   = prior_circSD_post_pred - prior_circSD_pre_pred,
         prior_pMem_pre_pred    = inv_logit(prior_pMem_int_draw),
         prior_pMem_post_pred   = inv_logit(prior_pMem_int_draw +
                                            rnorm(n(), prior_b_theta_stimulation, prior_sd_subj__theta_stimulation)),
         prior_pMem_ES_pred     = prior_pMem_post_pred - prior_pMem_pre_pred,
         
         # posteriors
         # group level parameters
         circSD_pre_mean  = exp(b_circSD_intercept),
         circSD_post_mean = exp(b_circSD_intercept + b_circSD_stimulation),
         circSD_ES_mean   = circSD_post_mean - circSD_pre_mean,
         pMem_pre_mean    = inv_logit(b_theta_intercept),
         pMem_post_mean   = inv_logit(b_theta_intercept + b_theta_stimulation),
         pMem_ES_mean     = pMem_post_mean - pMem_pre_mean,
         # one new subject per draw, shared by its pre and post predictions
         circSD_int_draw  = rnorm(n(), b_circSD_intercept, sd_subj__circSD_Intercept),
         pMem_int_draw    = rnorm(n(), b_theta_intercept, sd_subj__theta_Intercept),
         # predictive dist for a new subject
         circSD_pre_pred  = exp(circSD_int_draw),
         circSD_post_pred = exp(circSD_int_draw +
                                rnorm(n(), b_circSD_stimulation, sd_subj__circSD_stimulation)),
         circSD_ES_pred   = circSD_post_pred - circSD_pre_pred,
         pMem_pre_pred    = inv_logit(pMem_int_draw),
         pMem_post_pred   = inv_logit(pMem_int_draw +
                                      rnorm(n(), b_theta_stimulation, sd_subj__theta_stimulation)),
         pMem_ES_pred     = pMem_post_pred - pMem_pre_pred
         ) %>% 
  # drop raw coefficients / SDs and the helper intercept draws before reshaping
  select(-contains("b_"), -contains("sd_subj"), -ends_with("_int_draw")) %>%
  pivot_longer(-contains("."), names_to = c( "param", "stat"), names_pattern = "(.*)_(.*)", values_to = "value") %>%
  pivot_wider(names_from = stat, values_from = value)

# Group-level posteriors (with priors)

# circSD overview with priors. Each prior_* row is placed next to its
# posterior row via the factor levels. Each row shows:
#   - the group-mean distribution as a halfeye, nudged up;
#   - predictive intervals for a new subject;
#   - each subject's posterior medians as points.
circSD_p1_wPrior <- group_level_samples %>%
  filter(str_detect(param, "circSD")) %>%
  mutate(param = fct_relevel(param, c("prior_circSD_ES", "circSD_ES",
                                      "prior_circSD_pre", "circSD_pre",
                                      "prior_circSD_post", "circSD_post"))) %>%
  ggplot() +
  geom_halfeyeh(aes(y = param, x = mean), .width = c(.90, .95),
                position = position_nudge(y = 0.15)) +
  stat_intervalh(aes(y = param, x = pred), .width = c(.5, .8, .95)) +
  geom_point(data = circSD_subj_summary, aes(y = param, x = value), size = 2) +
  scale_color_brewer() +
  scale_x_continuous(breaks = pretty_breaks(10)) +
  coord_cartesian(xlim = c(-50, 150)) +
  labs(
    subtitle = "circSD: group level mean prior + posterior (median, 90%, 95% interval), \nsubject posterior medians, \ncondition predictive dist of subjects",
    x        = "circSD",
    color    = "interval"
  )

# Same prior + posterior overview for pMem (pre, post, ES), without an x-axis zoom.
pMem_p1_wPrior <- group_level_samples %>%
  filter(str_detect(param, "pMem")) %>%
  mutate(param = fct_relevel(param, c("prior_pMem_ES", "pMem_ES",
                                      "prior_pMem_pre", "pMem_pre",
                                      "prior_pMem_post", "pMem_post"))) %>%
  ggplot() +
  geom_halfeyeh(aes(y = param, x = mean), .width = c(.90, .95),
                position = position_nudge(y = 0.15)) +
  stat_intervalh(aes(y = param, x = pred), .width = c(.5, .8, .95)) +
  geom_point(data = pMem_subj_summary, aes(y = param, x = value), size = 2) +
  scale_color_brewer() +
  scale_x_continuous(breaks = pretty_breaks(10)) +
  labs(
    subtitle = "pMem: group level mean prior + posterior (median, 90%, 95% interval), \nsubject posterior medians, \ncondition predictive dist of subjects",
    x        = "pMem",
    color    = "interval"
  )

plot_grid(circSD_p1_wPrior, pMem_p1_wPrior, align = "hv", ncol = 1)

# Posterior predictive plots of errors ----

# Plotting helper functions

# Counts one simulated dataset's errors (obs_degree) for a single stimulation
# condition into 5-degree bins over [-180, 180].
# Returns a one-row data.frame with one column per bin. Each column is named
# from the bin's upper edge ("c5", "c10", ...). data.frame() name-checking
# turns negative labels such as "c-175" into "c.175".
sim_subj_obs_hist_count <- function(dataset, condition = 0){
  bin_edges <- seq(-180, 180, 5)

  errors_deg <- dataset %>%
    unnest(subj_obs) %>%
    ungroup() %>%
    filter(stimulation == condition) %>%
    pull(obs_degree)

  bin_counts <- hist(errors_deg, breaks = bin_edges, plot = FALSE)$counts
  bin_counts <- setNames(bin_counts, glue("c{bin_edges[-1]}"))

  data.frame(as.list(bin_counts))
}

# Per-bin quantiles of histogram counts across simulated datasets.
#
# sim_datasets : tibble with a list-column `dataset`, one simulated dataset
#                per row, each accepted by sim_subj_obs_hist_count().
# condition    : stimulation value to keep (0 = without, 1 = with stimulation).
# probs        : quantile levels; one output column "p<prob>" per level
#                (default: 0.1, 0.2, ..., 0.9).
# Returns a data.frame with one row per 5-degree bin, the p<prob> columns,
# and `xvals` (bin centres in degrees).
make_quantmat <- function(sim_datasets, condition = 0, probs = seq(0.1, 0.9, 0.1)){

  # one row per simulated dataset, one column per bin
  bincounts <- sim_datasets %>%
    select(dataset) %>%
    mutate(subj_hist_counts = map(dataset, sim_subj_obs_hist_count, condition)) %>%
    select(-dataset) %>%
    unnest(subj_hist_counts) %>%
    as_tibble()

  # bin centres matching the 5-degree breaks over [-180, 180]; fail loudly
  # instead of letting cbind() silently recycle a mismatched length
  xvals <- seq(-177.5, 177.5, 5)
  stopifnot(length(xvals) == ncol(bincounts))

  # every quantile of every bin in a single pass: rows = probs, cols = bins
  quant_by_bin <- vapply(
    bincounts,
    function(bin_counts) unname(quantile(bin_counts, probs = probs)),
    numeric(length(probs))
  )
  # matrix() keeps the result 2-D even when only one quantile is requested
  quant_by_bin <- matrix(quant_by_bin, nrow = length(probs))

  quantmat <- as.data.frame(t(quant_by_bin))
  names(quantmat) <- paste0("p", probs)

  cbind(quantmat, xvals)
}

# Simulate posterior predictive datasets

# Posterior predictive datasets that include group-level variance.
# 2000 random posterior draws are each passed to simulateData() as a
# 1-subject, 500-observations-per-condition simulation. Results are cached on
# disk; delete the .rds file to re-simulate.
post_pred_fpath <- glue("{params$save_dir_str}/post_pred_obs.rds")

if (!file.exists(post_pred_fpath)) {

  posterior_arranged <- spread_draws(model_fit, `(b|sd)_.*`, regex = TRUE) %>%
    sample_n(2e3) %>%
    mutate(
      sim_num       = seq_len(n()),
      nsubj         = 1,
      nobs_per_cond = 500
    ) %>%
    select(-contains(".")) %>%
    # rename posterior columns to simulateData()'s argument names
    select(
      sim_num,
      alpha0_mu    = b_circSD_intercept,
      alpha0_sigma = sd_subj__circSD_Intercept,
      alphaD_mu    = b_circSD_stimulation,
      alphaD_sigma = sd_subj__circSD_stimulation,
      beta0_mu     = b_theta_intercept,
      beta0_sigma  = sd_subj__theta_Intercept,
      betaD_mu     = b_theta_stimulation,
      betaD_sigma  = sd_subj__theta_stimulation,
      nsubj,
      nobs_per_cond
    )

  post_pred_sim <- posterior_arranged %>%
    mutate(dataset = pmap(posterior_arranged, simulateData))

  saveRDS(post_pred_sim, post_pred_fpath)

} else {

  post_pred_sim <- readRDS(post_pred_fpath)

}
# Posterior predictive datasets that use only the group means (no subject
# variance). 2000 random posterior draws are each passed to
# simulateData_subj() with 500 observations per condition. Results are cached
# on disk; delete the .rds file to re-simulate.
post_predmean_fpath <- glue("{params$save_dir_str}/post_predmean_obs.rds")

if (!file.exists(post_predmean_fpath)) {

  posterior_arranged <- spread_draws(model_fit, `(b|sd)_.*`, regex = TRUE) %>%
    sample_n(2e3) %>%
    mutate(nobs_per_condition = 500) %>%
    select(-contains(".")) %>%
    # rename posterior columns to simulateData_subj()'s argument names
    select(
      subj_alpha0 = b_circSD_intercept,
      subj_alphaD = b_circSD_stimulation,
      subj_beta0  = b_theta_intercept,
      subj_betaD  = b_theta_stimulation,
      nobs_per_condition
    )

  post_predmean_sim <- posterior_arranged %>%
    mutate(
      dataset = pmap(posterior_arranged, simulateData_subj),
      sim_num = seq_len(n())
    )

  saveRDS(post_predmean_sim, post_predmean_fpath)

} else {

  post_predmean_sim <- readRDS(post_predmean_fpath)

}

# Posterior data prediction, with group-level variance ----

# One randomly chosen observation per simulated dataset and condition.
single_sample <- post_pred_sim %>%
  unnest(dataset) %>%
  unnest(subj_obs) %>%
  mutate(stimulation = as_factor(stimulation)) %>%
  group_by(sim_num, stimulation) %>%
  sample_n(1) %>%
  ungroup()

# 25th / 75th percentiles of those sampled errors, per condition.
single_sample_quantile <- single_sample %>%
  group_by(stimulation) %>%
  summarise(
    lower = quantile(obs_degree, probs = 0.25),
    upper = quantile(obs_degree, probs = 0.75)
  )

# Copies shifted by -360 and +360 degrees, so that a density estimate near
# +/-180 degrees also sees the wrapped-around observations.
single_sample_edge_copies <- rbind(
  mutate(single_sample, obs_degree = obs_degree - 360),
  single_sample,
  mutate(single_sample, obs_degree = obs_degree + 360)
)
  
 
# Wrapped density of the sampled errors per condition (x clipped to
# [-180, 180]), with dashed lines at each condition's 25th/75th percentiles.
ggplot() +
  geom_density(
    data = single_sample_edge_copies,
    aes(x = obs_degree, fill = stimulation),
    alpha = 0.6, size = 0.8, bw = 15
  ) +
  geom_vline(
    data = pivot_longer(single_sample_quantile, c(lower, upper)),
    aes(xintercept = value, color = stimulation),
    size = 1, linetype = "dashed"
  ) +
  coord_cartesian(xlim = c(-180, 180), expand = FALSE) +
  expand_limits(y = 0) +
  scale_x_continuous(breaks = pretty_breaks(10)) +
  labs(
    title    = "per-condition posterior pred dist (using group mean, sd) + \n50% intervals",
    subtitle = "[edge-corrected, improper density]"
  ) +
  theme_cowplot()

# Empirical CDF of the sampled errors per condition, with faint guide lines at quartiles.
ggplot(single_sample) +
  stat_ecdf(aes(x = obs_degree, color = stimulation), size = 1) +
  geom_hline(yintercept = seq(0, 1, 0.25), linetype = "dashed", alpha = 0.2) +
  scale_x_continuous(breaks = pretty_breaks(12)) +
  scale_y_continuous(breaks = seq(0, 1, 0.25)) +
  coord_cartesian(xlim = c(-180, 180), expand = FALSE) +
  expand_limits(y = 0) +
  labs(
    y     = "cumulative probability",
    title = "per-condition posterior pred CDF (using group mean, sd)"
  ) +
  theme_cowplot()

# Per-bin quantiles of histogram counts over the simulated datasets, per condition.
quantmat_cond0 <- make_quantmat(post_pred_sim, 0)
quantmat_cond1 <- make_quantmat(post_pred_sim, 1)

c_light <- "#DCBCBC"
c_light_highlight <- "#C79999"
c_mid   <- "#B97C7C"
c_mid_highlight   <- "#A25050"
c_dark  <- "#8F2727"
c_dark_highlight  <- "#7C0000"

# Nested quantile bands (p0.1-p0.9 outermost, p0.4-p0.6 innermost) plus the
# median line, drawn from a make_quantmat() result. Labels are added per call.
plot_bin_quantiles <- function(quantmat) {
  ggplot(quantmat, aes(x = xvals)) +
    geom_ribbon(aes(ymax = p0.9, ymin = p0.1), fill = c_light) +
    geom_ribbon(aes(ymax = p0.8, ymin = p0.2), fill = c_light_highlight) +
    geom_ribbon(aes(ymax = p0.7, ymin = p0.3), fill = c_mid) +
    geom_ribbon(aes(ymax = p0.6, ymin = p0.4), fill = c_mid_highlight) +
    geom_line(aes(y = p0.5), color = c_dark, size = 1) +
    scale_x_continuous(breaks = pretty_breaks(10))
    #coord_cartesian(ylim = c(0, 20))
}

plot_grid(
  plot_bin_quantiles(quantmat_cond0) +
    labs(x = "error (degrees)", y = "count +/- quantile", subtitle = "without stimulation",
         title = "bin distribution over simulated datasets"),
  plot_bin_quantiles(quantmat_cond1) +
    labs(x = "error (degrees)", y = "count +/- quantile", subtitle = "with stimulation"),
  ncol = 1,
  align = "v"
)

posterior data prediction, without group-level variance

# Keep one random posterior predictive observation per (sim_num, stimulation)
# pair from the group-mean-only simulations. stimulation is converted to a
# factor so it maps onto discrete fill / colour scales in the plots.
single_sample <- local({
  # flatten the nested `dataset` list-column into one row per observation
  draws <- post_predmean_sim %>%
    unnest(dataset) %>%
    mutate(stimulation = as_factor(stimulation))

  draws %>%
    group_by(sim_num, stimulation) %>%
    sample_n(size = 1) %>%
    ungroup()
})

# Per-condition interquartile bounds of the sampled errors; these are the
# "50% intervals" drawn as dashed lines on the density plot.
single_sample_quantile <- summarise(
  group_by(single_sample, stimulation),
  lower = quantile(obs_degree, probs = 0.25),
  upper = quantile(obs_degree, probs = 0.75)
)


# Stack copies of the samples shifted by -360, 0 and +360 degrees so a kernel
# density evaluated on [-180, 180] is not truncated at the wrap-around edges.
single_sample_edge_copies <- do.call(
  rbind,
  lapply(c(-360, 0, 360), function(shift) {
    mutate(single_sample, obs_degree = obs_degree + shift)
  })
)
  
 
# Per-condition posterior predictive density using only the group mean, with
# each condition's 25% / 75% quantiles drawn as dashed vertical lines.
# The density is estimated on the -360 / 0 / +360 shifted copies so the kernel
# is not truncated at the +/-180 display edges (assumes obs_degree lies in
# [-180, 180] -- TODO confirm); because only the central window is shown, the
# visible curve does not integrate to 1 (hence "improper density").
ggplot() +
  geom_density(
    mapping = aes(x = obs_degree, fill = stimulation),
    data = single_sample_edge_copies,
    bw = 15, size = 0.8, alpha = 0.6
  ) +
  geom_vline(
    mapping = aes(xintercept = value, color = stimulation),
    data = pivot_longer(single_sample_quantile, c(lower, upper)),
    linetype = "dashed", size = 1
  ) +
  expand_limits(y = 0) +
  scale_x_continuous(breaks = pretty_breaks(10)) +
  coord_cartesian(xlim = c(-180, 180), expand = FALSE) +
  theme_cowplot() +
  labs(
    title = "per-condition posterior pred dist (using only group mean) + \n50% intervals",
    subtitle = "[edge-corrected, improper density]"
  )

# Empirical CDF of the group-mean-only posterior predictive samples, one curve
# per stimulation condition, with faint dashed guides at 0, .25, .5, .75 and 1.
ggplot(single_sample) +
  stat_ecdf(mapping = aes(x = obs_degree, color = stimulation), size = 1) +
  geom_hline(yintercept = seq(0, 1, 0.25), linetype = "dashed", alpha = 0.2) +
  expand_limits(y = 0) +
  scale_y_continuous(breaks = seq(0, 1, 0.25)) +
  scale_x_continuous(breaks = pretty_breaks(12)) +
  coord_cartesian(xlim = c(-180, 180), expand = FALSE) +
  theme_cowplot() +
  labs(
    y = "cumulative probability",
    title = "per-condition posterior pred CDF (using only group mean)"
  )

# Calculate quantile matrices for each condition (0 = without stimulation,
# 1 = with stimulation, matching the panel subtitles below).
# This section is the "without group-level variance" check, so the quantile
# fans must come from the same simulations as single_sample above
# (post_predmean_sim). Previously this used post_pred_sim, which silently
# re-plotted the previous section's fans.
# NOTE(review): assumes make_quantmat accepts post_predmean_sim (it has the
# same nested `dataset` layout that single_sample unnests) -- confirm.
quantmat_cond0 <- make_quantmat(post_predmean_sim, 0)
quantmat_cond1 <- make_quantmat(post_predmean_sim, 1)

 
# Red palette for the quantile-ribbon plots, darkest first: the dark shades
# mark the median line / innermost band, the light shades the outer bands.
c_dark_highlight  <- "#7C0000"
c_dark            <- "#8F2727"
c_mid_highlight   <- "#A25050"
c_mid             <- "#B97C7C"
c_light_highlight <- "#C79999"
c_light           <- "#DCBCBC"


# Two vertically aligned "quantile fan" panels (without / with stimulation).
# Nested ribbons span the p0.1-p0.9, p0.2-p0.8, p0.3-p0.7 and p0.4-p0.6 columns
# of each quantile matrix, and the line traces p0.5 (presumably per-bin
# quantiles of counts across simulated datasets, as produced by make_quantmat).
# The helper lives inside local() so it does not leak into the global env.
local({
  draw_quantile_fan <- function(quantmat) {
    ggplot(quantmat, aes(x = xvals)) +
      geom_ribbon(aes(ymax = p0.9, ymin = p0.1), fill = c_light) +
      geom_ribbon(aes(ymax = p0.8, ymin = p0.2), fill = c_light_highlight) +
      geom_ribbon(aes(ymax = p0.7, ymin = p0.3), fill = c_mid) +
      geom_ribbon(aes(ymax = p0.6, ymin = p0.4), fill = c_mid_highlight) +
      geom_line(aes(y = p0.5), color = c_dark, size = 1) +
      scale_x_continuous(breaks = pretty_breaks(10))
    # optional y cap: + coord_cartesian(ylim = c(0, 20))
  }

  fan_without <- draw_quantile_fan(quantmat_cond0) +
    labs(x = "error (degrees)", y = "count +/- quantile",
         subtitle = "without stimulation",
         title = "bin distribution over simulated datasets")

  fan_with <- draw_quantile_fan(quantmat_cond1) +
    labs(x = "error (degrees)", y = "count +/- quantile",
         subtitle = "with stimulation")

  plot_grid(fan_without, fan_with, ncol = 1, align = "v")
})